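Last time we finished collecting the 2021 stock data for Formosa Plastics (台塑), so now we can start building our RNN model. We first wrap the main pieces of functionality in functions. This keeps the final code cleaner and easier to read, and it makes later changes easier. We split the work into three functions: loading the data, building the model, and training the model.
The first function loads the data. Its parameters are the features, the labels, the sequence length, and the ratio used to split the training and test sets. As with the CNN, we first convert the data to a suitable numeric type (a float array) and then normalize the feature values. Next, we store every ten days of closing, high, and low prices as one sample, and the label for that sample is the closing price of the following day. Finally, as with the CNN, we split the samples into a training set and a test set so that the final results can be validated.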
def load_data(df, dfp, sequence_length = 10, split = 0.8):
    # note: scaler and scalert are scaler objects that must already exist
    # at module level before this function is called
    # convert features from dataframe to float
    data_all = np.array(df).astype(float)
    # scale features between 0 and 1
    data_all = scaler.fit_transform(data_all)
    # convert labels from dataframe to float
    datap_all = np.array(dfp).astype(float)
    print("datap_all's shape: ", datap_all.shape)
    # scale labels between 0 and 1
    datap_all = scalert.fit_transform(datap_all)
    # group every sequence_length days into one sample; the label is the day after the sample
    data = []
    datap = []
    for i in range(len(data_all) - sequence_length):
        data.append(data_all[i: i + sequence_length])
        datap.append(datap_all[i + sequence_length])
    # convert features and labels from lists into float arrays
    x = np.array(data).astype("float64")
    y = np.array(datap).astype("float64")
    # split into training set and test set (no shuffling, so time order is kept)
    split_boundary = int(x.shape[0] * split)
    train_x = x[:split_boundary]
    test_x = x[split_boundary:]
    train_y = y[:split_boundary]
    test_y = y[split_boundary:]
    return train_x, train_y, test_x, test_y
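To make the windowing step easier to picture, here is a minimal sketch on made-up data. The helper name `make_windows` and the toy arrays are assumptions for illustration only; they are not part of the original code, and the scaling step is left out so that the shapes are easy to follow.

```python
import numpy as np

def make_windows(features, labels, sequence_length):
    """Pair each run of `sequence_length` rows with the label of the next row."""
    xs, ys = [], []
    for i in range(len(features) - sequence_length):
        xs.append(features[i: i + sequence_length])
        ys.append(labels[i + sequence_length])
    return np.array(xs, dtype="float64"), np.array(ys, dtype="float64")

# Toy data (hypothetical): 15 "days", 3 features per day, 1 label per day.
toy_features = np.arange(45, dtype=float).reshape(15, 3)
toy_labels = np.arange(15, dtype=float).reshape(15, 1)

x, y = make_windows(toy_features, toy_labels, sequence_length=10)
split_boundary = int(x.shape[0] * 0.8)

print(x.shape, y.shape)   # (5, 10, 3) (5, 1)
print(split_boundary)     # 4
print(y[0])               # [10.]
```

With 15 days and a window of 10, only 5 samples come out, and the first sample (days 0 to 9) is labelled with day 10. The same arithmetic applies to the real data: the number of samples is the number of trading days minus `sequence_length`.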
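The second function builds the model. It takes no parameters and returns the model. We use an LSTM here because it generally captures longer-range dependencies in a sequence better than a plain RNN.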
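def build_model():
    # create a model
    model = Sequential()
    # add LSTM layer: each sample is 10 days x 3 features
    model.add(LSTM(input_shape = (10, 3), units = 256, unroll = False))
    # add output layer: one value, the scaled closing price of the next day
    model.add(Dense(units = 1))
    # set the training configuration; accuracy is not meaningful for regression,
    # so mean absolute error is tracked instead
    model.compile(loss = "mse", optimizer = "adam", metrics = ["mae"])
    return model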
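As a rough check on the size of this network, the number of trainable parameters can be worked out by hand. The sketch below assumes the standard LSTM formulation with four gate blocks, each with input weights, recurrent weights, and a bias. The function names are hypothetical helpers, not Keras APIs.

```python
def lstm_param_count(input_dim, units):
    # Four gate blocks, each with input weights, recurrent weights and a bias.
    return 4 * (units * (input_dim + units) + units)

def dense_param_count(input_dim, units):
    # One weight per input-output pair plus one bias per output unit.
    return input_dim * units + units

lstm_params = lstm_param_count(input_dim=3, units=256)
dense_params = dense_param_count(input_dim=256, units=1)

print(lstm_params)                 # 266240
print(dense_params)                # 257
print(lstm_params + dense_params)  # 266497
```

If that assumption holds for the installed Keras version, these totals should agree with what `model.summary()` reports. Almost all of the parameters sit in the recurrent weights of the LSTM layer, which is worth keeping in mind given that one year of daily prices yields only a couple of hundred samples.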
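The third function trains the model. Its parameters are the training data, the training labels, and the test data (the signature also accepts test_y, which is not used inside the function). It trains the model on the given training data and labels, uses the trained model to predict on the test data, and returns the predictions.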
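def train_model(train_x, train_y, test_x, test_y):
    # build the model here so the function does not depend on a global variable
    model = build_model()
    # start training model; the last 10% of the training set is held out for validation
    model.fit(train_x, train_y, batch_size = 100, epochs = 300, validation_split = 0.1)
    # do prediction on test set
    predict = model.predict(test_x)
    # flatten the predictions into a one-dimensional array
    predict = np.reshape(predict, (predict.size,))
    return predict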
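Because the labels were scaled to the range 0 to 1 before training, the returned predictions are in scaled units too. The sketch below shows, on made-up numbers, how min-max scaling can be undone so that the error reads in price units. It writes the scaling out by hand with NumPy and does not use the `scalert` object from the post. If `scalert` is a scikit-learn scaler, its `inverse_transform` method on a two-dimensional array should play the same role, but that is an assumption about code not shown here.

```python
import numpy as np

# Made-up closing prices standing in for the real label column.
prices = np.array([[100.0], [110.0], [90.0], [105.0]])

# Min-max scaling to [0, 1], written out by hand.
p_min, p_max = prices.min(), prices.max()
scaled = (prices - p_min) / (p_max - p_min)

# Pretend these are model outputs in scaled units (hypothetical values),
# compared against the last two scaled labels.
predict = np.array([0.5, 0.75])
actual_scaled = scaled[2:, 0]

# Undo the scaling so the error is expressed in price units.
predict_price = predict * (p_max - p_min) + p_min
actual_price = actual_scaled * (p_max - p_min) + p_min

# Root mean squared error between predicted and actual prices.
rmse = np.sqrt(np.mean((predict_price - actual_price) ** 2))
print(predict_price)
print(actual_price)
print(rmse)
```

In this toy case the predictions map back to 100 and 105 and the actual prices to 90 and 105, so the root mean squared error is the square root of 50, roughly 7.07.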